{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Track an experiment while training a PyTorch model locally or in your notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Introduction\n",
    "\n",
    "This notebook shows how to use the SageMakerCore SDK to track a machine learning experiment for a PyTorch model trained locally.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment\n",
    "An experiment is a collection of runs. When you initialize a run in your training loop, you include the name of the experiment that the run belongs to. Experiment names must be unique within your AWS account."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Latest SageMakerCore\n",
    "All SageMakerCore beta distributions are released to a private S3 bucket. After being allowlisted, run the cells below to install the latest version of SageMakerCore from `s3://sagemaker-core-beta-artifacts/sagemaker_core-latest.tar.gz`.\n",
    "\n",
    "Ensure you are using a kernel with Python version >=3.8."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uninstall previous version of sagemaker-core and restart kernel\n",
    "!pip uninstall sagemaker-core -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make dist/ directory to hold the sagemaker-core beta distribution file\n",
    "!mkdir dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download and Install the latest version of sagemaker-core\n",
    "!aws s3 cp s3://sagemaker-core-beta-artifacts/sagemaker_core-latest.tar.gz dist/\n",
    "\n",
    "!pip install dist/sagemaker_core-latest.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the version of sagemaker-core\n",
    "!pip show -v sagemaker-core"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Additional Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install additional packages\n",
    "\n",
    "!pip install -U torch torchvision matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup\n",
    "\n",
    "Import required libraries and set logging and experiment configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "import torch\n",
    "import os\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "from sagemaker.core.helper.session_helper import Session\n",
    "from sagemaker.core.utils import get_textual_rich_logger\n",
    "\n",
    "logger = get_textual_rich_logger(__name__)\n",
    "session = Session()\n",
    "region = session.boto_region_name\n",
    "\n",
    "experiment_name = \"local-pytorch-experiment-example-\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n",
    "run_group_name = \"Default-Run-Group-\" + experiment_name\n",
    "run_name = \"local-experiment-run-\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the dataset\n",
    "\n",
    "Use the torchvision library to download the MNIST dataset and apply a normalization transform to each image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# download the dataset\n",
    "# this will not only download data to ./mnist folder, but also load and transform (normalize) them\n",
    "datasets.MNIST.urls = [\n",
    "    f\"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/train-images-idx3-ubyte.gz\",\n",
    "    f\"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/train-labels-idx1-ubyte.gz\",\n",
    "    f\"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/t10k-images-idx3-ubyte.gz\",\n",
    "    f\"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/t10k-labels-idx1-ubyte.gz\",\n",
    "]\n",
    "\n",
    "train_set = datasets.MNIST(\n",
    "    \"mnist_data\",\n",
    "    train=True,\n",
    "    transform=transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    ),\n",
    "    download=True,\n",
    ")\n",
    "\n",
    "test_set = datasets.MNIST(\n",
    "    \"mnist_data\",\n",
    "    train=False,\n",
    "    transform=transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    ),\n",
    "    download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_client = session.s3_client\n",
    "bucket_name = session.default_bucket()\n",
    "for f in os.listdir(train_set.raw_folder):\n",
    "    file_path = train_set.raw_folder + \"/\" + f\n",
    "    s3_client.upload_file(file_path, bucket_name, file_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "View an example image from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(train_set.data[2].numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create experiment and log dataset information\n",
    "\n",
    "Create an experiment run to track the model training. SageMaker Experiments is a great way to organize your data science work. You can create an experiment to organize all your model runs and analyse the different model metrics with the SageMaker Experiments UI.\n",
    "\n",
    "Here we create an experiment together with a trial and trial component for it. We also log all the downloaded files as inputs to our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.resources import Experiment, Trial, TrialComponent\n",
    "from sagemaker.core.shapes import TrialComponentParameterValue, TrialComponentArtifact\n",
    "\n",
    "experiment = Experiment.create(experiment_name=experiment_name)\n",
    "trial = Trial.create(trial_name=run_group_name, experiment_name=experiment_name)\n",
    "\n",
    "trial_component_parameters = {\n",
    "    \"num_train_samples\": TrialComponentParameterValue(number_value=len(train_set.data)), \n",
    "    \"num_test_samples\": TrialComponentParameterValue(number_value=len(test_set.data)),\n",
    "}\n",
    "\n",
    "# Setting input dataset file path\n",
    "trial_component_input_artifacts = {}\n",
    "for f in os.listdir(train_set.raw_folder):\n",
    "    file_path = train_set.raw_folder + \"/\" + f\n",
    "    trial_component_input_artifacts[f] = TrialComponentArtifact(value=file_path)\n",
    "\n",
    "trial_component = TrialComponent.create(\n",
    "    trial_component_name=run_name,\n",
    "    parameters=trial_component_parameters,\n",
    "    input_artifacts=trial_component_input_artifacts,\n",
    ")\n",
    "trial_component.associate_trail(trial_name=run_group_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the SageMaker Experiments UI, you can see that a new experiment was created with the run associated with it."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/experiment_created.png\" width=\"100%\" style=\"float: left;\" />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create model training functions\n",
    "\n",
    "Define the pieces needed to train and evaluate the model: a convolutional neural network (`Net`), a helper (`record_performance`) that records the loss and accuracy for every epoch, and a training function (`train_model`).\n",
    "\n",
    "`train_model` logs the training parameters to the trial component before training and uploads the collected metrics once training has finished."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self, hidden_channels, kernel_size, drop_out):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = torch.nn.Conv2d(1, hidden_channels, kernel_size=kernel_size)\n",
    "        self.conv2 = torch.nn.Conv2d(hidden_channels, 20, kernel_size=kernel_size)\n",
    "        self.conv2_drop = torch.nn.Dropout2d(p=drop_out)\n",
    "        self.fc1 = torch.nn.Linear(320, 50)\n",
    "        self.fc2 = torch.nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv1(x), 2))\n",
    "        x = torch.nn.functional.relu(\n",
    "            torch.nn.functional.max_pool2d(self.conv2_drop(self.conv2(x)), 2)\n",
    "        )\n",
    "        x = x.view(-1, 320)\n",
    "        x = torch.nn.functional.relu(self.fc1(x))\n",
    "        x = torch.nn.functional.dropout(x, training=self.training)\n",
    "        x = self.fc2(x)\n",
    "        return torch.nn.functional.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def record_performance(model, data_loader, device, epoch, metrics, metric_type=\"Test\"):\n",
    "    \"\"\"\n",
    "    Record performance metric for every epoch, these metrics will be uploaded \n",
    "    after the training is finished\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in data_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            loss += torch.nn.functional.nll_loss(\n",
    "                output, target, reduction=\"sum\"\n",
    "            ).item()  # sum up batch loss\n",
    "            # get the index of the max log-probability\n",
    "            pred = output.max(1, keepdim=True)[1]\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    loss /= len(data_loader.dataset)\n",
    "    accuracy = 100.0 * correct / len(data_loader.dataset)\n",
    "    # record metrics\n",
    "    loss_metric = {\n",
    "        \"MetricName\": metric_type + \":loss\",\n",
    "        \"Value\": loss,\n",
    "        \"Step\": epoch,\n",
    "        \"Timestamp\": time.time(),\n",
    "    }\n",
    "    metrics.append(loss_metric)\n",
    "    accuracy_metric = {\n",
    "        \"MetricName\": metric_type + \":accuracy\",\n",
    "        \"Value\": accuracy,\n",
    "        \"Step\": epoch,\n",
    "        \"Timestamp\": time.time(),\n",
    "    }\n",
    "    metrics.append(accuracy_metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(\n",
    "    trial_component, train_set, test_set, data_dir=\"mnist_data\", optimizer=\"sgd\", epochs=10, hidden_channels=10\n",
    "):\n",
    "    \"\"\"\n",
    "    Function that trains the CNN classifier to identify the MNIST digits.\n",
    "    Args:\n",
    "        trial_component (sagemaker.core.resources.TrialComponent): trial component that tracks this run\n",
    "        train_set (torchvision.datasets.mnist.MNIST): train dataset\n",
    "        test_set (torchvision.datasets.mnist.MNIST): test dataset\n",
    "        data_dir (str): local directory where the MNIST datasource is stored\n",
    "        optimizer (str): the optimization algorithm to use for training your CNN\n",
    "                         available options are sgd and adam\n",
    "        epochs (int): number of complete passes of the training dataset through the algorithm\n",
    "        hidden_channels (int): number of hidden channels in your model\n",
    "    \"\"\"\n",
    "\n",
    "    # log the parameters of your model\n",
    "    training_parameters = {\n",
    "        \"device\": TrialComponentParameterValue(string_value=\"cpu\"),\n",
    "        \"data_dir\": TrialComponentParameterValue(string_value=data_dir),\n",
    "        \"optimizer\": TrialComponentParameterValue(string_value=optimizer),\n",
    "        \"epochs\": TrialComponentParameterValue(number_value=epochs),\n",
    "        \"hidden_channels\": TrialComponentParameterValue(number_value=hidden_channels),\n",
    "    }\n",
    "    trial_component.update(parameters=training_parameters)\n",
    "\n",
    "    # train the model on the CPU (no GPU)\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "    # set the seed for generating random numbers\n",
    "    torch.manual_seed(42)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True)\n",
    "    model = Net(hidden_channels, kernel_size=5, drop_out=0.5).to(device)\n",
    "    model = torch.nn.DataParallel(model)\n",
    "    momentum = 0.5\n",
    "    lr = 0.01\n",
    "    if optimizer == \"sgd\":\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "        \n",
    "    metrics = []\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        logger.info(f\"Training Epoch: {epoch}\")\n",
    "        model.train()\n",
    "        for batch_idx, (data, target) in enumerate(train_loader, 1):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = torch.nn.functional.nll_loss(output, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        record_performance(model, train_loader, device, epoch, metrics, \"Train\")\n",
    "        record_performance(model, test_loader, device, epoch, metrics, \"Test\")\n",
    "            \n",
    "    trial_component.batch_put_metrics(MetricData=metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start the first run in your experiment\n",
    "\n",
    "Call the `train_model` function with `trial_component` as a parameter to start a run. Here we train the CNN for 5 epochs with 2 hidden channels and Adam as the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_model(\n",
    "    trial_component=trial_component,\n",
    "    train_set=train_set,\n",
    "    test_set=test_set,\n",
    "    epochs=5,\n",
    "    hidden_channels=2,\n",
    "    optimizer=\"adam\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the SageMaker Experiments UI, you can observe that the new model parameters are added to the run. The model training metrics are captured and can be used to plot graphs in Experiments -> select experiment -> Runs -> Analyze."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/experiment_run_parameters.png\" width=\"100%\" style=\"float: left;\" />\n",
    "<img src=\"images/experiment_run_metrics.png\" width=\"100%\" style=\"float: left;\" />\n",
    "<img src=\"images/experiment_run_analyze_plot.png\" width=\"100%\" style=\"float: left;\" />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create multiple runs\n",
    "\n",
    "You can now create multiple runs of your experiment using the functions defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the list of parameters to train the model with\n",
    "num_hidden_channel_param = [5, 10]\n",
    "optimizer_param = [\"adam\", \"sgd\"]\n",
    "run_id = 0\n",
    "# train the model using SageMaker Experiments to track the model parameters,\n",
    "# metrics and performance\n",
    "for num_hidden_channel in num_hidden_channel_param:\n",
    "    for optimizer in optimizer_param:\n",
    "        run_id += 1\n",
    "        run_name = \"local-experiment-run-\" + str(run_id) + \"-\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n",
    "        \n",
    "        # Defining an experiment run for each model training run\n",
    "        trial_component = TrialComponent.create(\n",
    "            trial_component_name=run_name,\n",
    "            parameters=trial_component_parameters,\n",
    "            input_artifacts=trial_component_input_artifacts,\n",
    "        )\n",
    "        trial_component.associate_trail(trial_name=run_group_name)\n",
    "        \n",
    "        logger.info(\n",
    "            f\"{run_name}: Training model with {num_hidden_channel} hidden channels and {optimizer} as optimizer\"\n",
    "        )\n",
    "        train_model(\n",
    "            trial_component=trial_component,\n",
    "            train_set=train_set,\n",
    "            test_set=test_set,\n",
    "            epochs=5,\n",
    "            hidden_channels=num_hidden_channel,\n",
    "            optimizer=optimizer,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/experiment_runs_comparison.png\" width=\"100%\" style=\"float: left;\" />"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
